# ------------------------------------------------------------------------------
# Trains OOS ensemble models on validation sets
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=10000]" bash 06_train_ensemble.sh {outcome} {split} {restriction}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix the RNG state before anything else runs so that any random draws made
# later in the script are reproducible from run to run.
# NOTE(review): nothing visible in this script draws random numbers directly;
# confirm whether a sourced helper depends on this seed.
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(Matrix)
library(tidyverse)
library(glue)
library(optparse)
library(here) # here() relative filepaths
library(testit) # assert()

# Project utilities (loaded as a module from lib/util.R) and the directory
# that holds this modeling step's files, including the ensemble configs.
u <- modules::use(here("lib", "util.R"))
dir <- here::here("code", "03_analysis", "01_build_models")

# Command Line Args ------------------------------------------------------------
# Three string options are accepted; none of them has a default value.
arg_config <- list(
  make_option("--outcome", type = "character"),
  make_option("--split", type = "character"),
  make_option("--restriction", type = "character")
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

# Fail fast with a readable message when an option was not supplied, instead
# of letting a NULL propagate into file paths and `if` conditions downstream.
required_args <- c("outcome", "split", "restriction")
missing_args <- required_args[
  vapply(required_args, function(nm) is.null(arg_list[[nm]]), logical(1))
]
if (length(missing_args) > 0) {
  stop(
    "Missing required argument(s): ",
    paste0("--", missing_args, collapse = ", "),
    call. = FALSE
  )
}

# NOTE(review): `split` is never referenced by name in this script (and it
# masks base::split). It is presumably interpolated by a glue() template read
# from filepaths.yml -- confirm before renaming or removing it.
split <- arg_list$split

# Helpers ----------------------------------------------------------------------
# Shrink a fitted model object before it is written to disk by removing the
# per-observation components listed below. Components that are not present
# are left alone; everything else (including the class) is untouched.
#
# model: a fitted model object (list-like, e.g. the result of glm()).
# Returns the same object without the dropped components.
trim_glm <- function(model) {
  bulky_fields <- c("residuals", "fitted.values", "weights", "prior.weights")
  for (field in bulky_fields) {
    # Assigning NULL deletes the element (a no-op when it does not exist).
    model[[field]] <- NULL
  }
  model
}
# Directories ------------------------------------------------------------------
message("Establishing directories...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# Input and output locations share one layout under the OOS modeling root:
#   <oos root>/<kind>/<split>/<restriction>
oos_subdir <- function(kind) {
  file.path(paths$modeling$oos, kind, arg_list$split, arg_list$restriction)
}
subscores_dir <- oos_subdir("subscores")
models_dir <- oos_subdir("models")

# Load Data --------------------------------------------------------------------
message("Loading data...")

# Ensemble specification for the requested outcome.
config_path <- file.path(
  dir, "ensemble_config", glue("{arg_list$outcome}.yml")
)
config <- read_yaml(config_path)

# From the training data keep only the encounter id, the testing flag and the
# target column(s); unique() guards against the target repeating one of them.
id_cols <- unique(c("ed_enc_id", "test_010_day", config$target))
# The training path is a glue template; it is interpolated here at top level
# so it sees the script's global variables.
train_path <- glue(paths$modeling$train)
ids <- select(readRDS(train_path), all_of(id_cols))

# Prep for models --------------------------------------------------------------
message("Preparing modeling parameters...")

# Use z (log) variables for logit ensembles and p (probability) variables for
# OLS ensembles. Only the prefix differs between the two cases, so pick it
# once instead of duplicating the renaming logic in each branch.
score_prefix <- if (config$logit) "z__" else "p__"
covariates <- map(
  config$components, ~ (
    str_replace(.x, "gbm__", paste0(score_prefix, "gbm__")) %>%
      str_replace("lasso__", paste0(score_prefix, "lasso__"))
  )
) %>% unlist()

# For any restriction other than "all", the covariate names get the
# restriction appended as a suffix.
if (arg_list$restriction != "all") {
  covariates <- glue("{covariates}__{arg_list$restriction}")
}

# A "-1" term removes the intercept when the config asks for no constant.
if (!config$constant) {
  covariates <- c(covariates, "-1")
}

model_formula <- reformulate(
  response = config$target,
  termlabels = covariates
)

# Scalar choice, so plain if/else rather than the vectorised ifelse().
model_family <- if (config$logit) "binomial" else "gaussian"

# Ensembling -------------------------------------------------------------------
# For each fold: read that fold's subscores, keep the rows whose train_fold
# differs from the fold index, join on the id/target columns, restrict to the
# configured population minus excluded encounters, fit the ensemble GLM and
# save a trimmed copy of the model.

# Make sure the output directory exists before spending time on model fits;
# this is a no-op when the directory is already there.
dir.create(models_dir, recursive = TRUE, showWarnings = FALSE)

n_folds <- 5L
for (i in seq_len(n_folds)) {
  message("Ensembling subscores for fold ", i, "...")
  subscores <- readRDS(file.path(subscores_dir, glue("subscores__{i}.rds")))
  ensemble_df <- subscores %>%
    filter(train_fold != i) %>%
    u$safe_left_join(ids) %>%
    setDT()
  assert("All encs in ensemble", all(ensemble_df$in_ensemble))

  # Population filter. An unrecognised value must be an error: without a
  # default branch switch() returns NULL, which silently yields an empty mask.
  keep_pop <- switch(config$population,
    all = rep(TRUE, nrow(ensemble_df)),
    tested = ensemble_df$test_010_day == TRUE,
    untested = ensemble_df$test_010_day == FALSE,
    stop(
      "Unknown population in ensemble config: ", config$population,
      call. = FALSE
    )
  )

  # Drop encounters carrying either exclusion flag.
  # NOTE(review): an earlier version of this mask repeated `excl_flag_chronic`,
  # which looks like a copy-paste slip; confirm whether a third exclusion flag
  # was meant to be applied here.
  excluded <- ensemble_df$excl_flag_chronic | ensemble_df$excl_flag_c_int
  keep_pop <- keep_pop & !excluded
  train_df <- ensemble_df[keep_pop, ]

  # model = FALSE / y = FALSE keep the fitted object small; trim_glm() then
  # removes further per-observation components before saving.
  model <- glm(model_formula,
    data = train_df, family = model_family,
    model = FALSE, y = FALSE,
    control = list(maxit = 50, trace = TRUE)
  )
  model <- trim_glm(model)

  fp <- file.path(models_dir, glue("ensemble__{arg_list$outcome}__{i}.rds"))
  message("Saving full model for fold ", i, " to ", fp, "...")
  write_rds(model, fp)
}

message("Done.")
